import os
import pickle

import numpy as np
# from txt2graph import txt2generalgraph
from causallearn.utils.TXT2GeneralGraph import txt2generalgraph


def sachs_init():
    """Build the Sachs dataset dictionary from text files and pickle it.

    Reads ``sachs.txt`` from the current working directory. The first line
    holds the space-separated variable names; every following non-blank line
    holds one sample as space-separated integers.

    The resulting dictionary is written to ``sachs.pkl`` (also in the current
    working directory) and returned.

    Returns:
        dict with the keys
            ``data_mat``   -- samples as a float64 array, one row per sample
            ``data_name``  -- the string ``"sachs"``
            ``threshold``  -- fixed setting ``0.8``
            ``width_init`` -- fixed setting ``0.001``
            ``var_name``   -- list of variable names from the header line
            ``var_idx``    -- column vector ``[[0], [1], ...]``, one entry
                              per variable name
            ``G``          -- whatever ``graph_from_txt()`` returns

    Raises:
        FileNotFoundError: if ``sachs.txt`` is missing.
        ValueError: if a data field is not an integer or rows differ in length.
    """
    variable_name = []
    rows = []
    with open("sachs.txt", 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f):
            # Strip only the line terminator. Slicing off the last character
            # unconditionally would truncate the final value when the file
            # does not end with a newline.
            line = raw_line.rstrip('\r\n')
            if line_no == 0:
                variable_name = line.split(' ')
                continue
            if not line.strip():
                # Tolerate blank lines (typically a trailing one at EOF).
                continue
            rows.append([int(field) for field in line.split(' ')])

    # Build the matrix once instead of concatenating row by row (quadratic).
    # With no data rows, fall back to the same empty array as before.
    variable = np.array(rows, dtype=int) if rows else np.array([])

    Data = {}
    Data['data_mat'] = variable.astype(np.float64)
    Data['data_name'] = "sachs"
    Data['threshold'] = 0.8
    Data['width_init'] = 0.001
    Data['var_name'] = variable_name
    Data['var_idx'] = np.array([[idx] for idx in range(len(variable_name))])
    Data['G'] = graph_from_txt()

    # NOTE(review): sachs_load() reads "Data/SACHS/sachs.pkl", while this
    # writes "sachs.pkl" relative to the working directory -- confirm the
    # script is run from (or the file is moved to) the expected location.
    with open("sachs.pkl", "wb") as tf:
        pickle.dump(Data, tf)

    return Data

def graph_from_txt():
    """Load the reference graph stored in ``sachs.graph.txt``.

    The file is parsed with ``txt2generalgraph`` and the ``graph`` attribute
    of the resulting object is returned (presumably the adjacency matrix of
    the ground-truth DAG -- confirm against the causallearn documentation).
    """
    return txt2generalgraph("sachs.graph.txt").graph

def graph():
    """Parse ``sachs.graph.txt`` into an integer matrix and mirror ``-1`` marks.

    The file is read from the current working directory. Its first line is a
    tab-separated header, which is skipped; every following non-blank line is
    one matrix row of tab-separated integers.

    After parsing, the matrix is scanned in row-major order and, for every
    entry ``G[i, j] == -1``, the transposed entry ``G[j, i]`` is set to ``1``
    (presumably encoding the two endpoints of a directed edge -- confirm
    against the file format).

    Returns:
        numpy.ndarray: the integer matrix, or an empty 1-D array when the
        file contains no data rows.

    Raises:
        FileNotFoundError: if ``sachs.graph.txt`` is missing.
        ValueError: if a field is not an integer or rows differ in length.
    """
    rows = []
    with open("sachs.graph.txt", 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f):
            if line_no == 0:
                # Header line with the variable names; not needed here.
                continue
            # Strip only the line terminator. Slicing off the last character
            # unconditionally would truncate the final value when the file
            # does not end with a newline.
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                # Tolerate blank lines (typically a trailing one at EOF).
                continue
            rows.append([int(field) for field in line.split('\t')])

    # Build the matrix once instead of concatenating row by row (quadratic).
    G = np.array(rows, dtype=int) if rows else np.array([])

    # Deliberately sequential: an entry that an earlier step already
    # overwrote with 1 must no longer trigger a mirror write, so this is
    # not replaced with a single masked assignment.
    n = G.shape[0]
    for i in range(n):
        for j in range(n):
            if G[i, j] == -1:
                G[j, i] = 1

    return G

def sachs_load(nums, seeds):
    """Load the pickled Sachs dataset and keep a random subset of its rows.

    Args:
        nums: number of rows to draw (without replacement) from ``data_mat``.
        seeds: value passed to ``np.random.seed`` before sampling; note that
            this reseeds NumPy's global random state.

    Returns:
        The dictionary unpickled from ``Data/SACHS/sachs.pkl`` with its
        ``data_mat`` entry replaced by the sampled rows.
    """
    # Seed before anything else so the draw below is reproducible.
    np.random.seed(seeds)

    # NOTE: pickle.load executes arbitrary code from the file; only use this
    # on a trusted, locally generated pickle.
    with open("Data/SACHS/sachs.pkl", "rb") as pkl_file:
        dataset = pickle.load(pkl_file)

    full_matrix = dataset['data_mat']
    chosen_rows = np.random.choice(full_matrix.shape[0], nums, replace=False)
    dataset['data_mat'] = full_matrix[chosen_rows]
    return dataset

if __name__ == '__main__':
    # Script entry point: rebuild the dataset dictionary from the text files
    # and write sachs.pkl into the current working directory.
    # Debug helper (disabled): print(graph())
    data_dir = sachs_init()
    # Disabled examples below. NOTE: sachs_load requires both `nums` and
    # `seeds`, so the call would need a second argument if re-enabled.
    # data_dir = sachs_load(100)
    # print(data_dir['G'])